{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The NEF algorithm\n",
    "\n",
    "While Nengo provides a flexible,\n",
    "general-purpose approach to neural modelling,\n",
    "some aspects rely on the Neural Engineering Framework (NEF)\n",
    "to help specify network behavior.\n",
    "The theory behind the Neural Engineering Framework\n",
    "is developed at length in\n",
    "[Eliasmith & Anderson, 2003: \"Neural Engineering\"](\n",
    "https://mitpress.mit.edu/books/neural-engineering)\n",
    "and a short summary is available in\n",
    "[Stewart, 2012: \"A Technical Overview of the Neural Engineering Framework\"](\n",
    "http://compneuro.uwaterloo.ca/publications/stewart2012d.html).\n",
    "\n",
    "However, for some people,\n",
    "the best description of an algorithm\n",
    "is the code itself.\n",
    "With that in mind,\n",
    "the following is a complete implementation of the NEF\n",
    "for the special case of\n",
    "two one-dimensional populations with a single connection between them.\n",
    "You can adjust the function being computed,\n",
    "the input to the system,\n",
    "and various neural parameters.\n",
    "\n",
    "This example does not use Nengo at all.\n",
    "It is Python code that only requires\n",
    "NumPy (for the matrix pseudo-inverse used to find the decoders)\n",
    "and Matplotlib (to produce graphs of the output)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "The NEF is a method for building large-scale neural models\n",
    "out of biologically plausible spiking neurons.\n",
    "It acts as a neural compiler:\n",
    "you specify the high-level computations the model should perform\n",
    "and the properties of the neurons themselves,\n",
    "and the NEF determines the connection weights\n",
    "needed to perform those computations.\n",
    "\n",
    "This script shows how to build a\n",
    "simple feed-forward network of leaky integrate-and-fire neurons\n",
    "where each population encodes a one-dimensional value\n",
    "and the connection weights between the populations\n",
    "are optimized to compute some arbitrary function.\n",
    "The same approach is used in Nengo,\n",
    "extended to multi-dimensional representations,\n",
    "many populations of neurons, and recurrent connections."
   ]
  },
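  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encode/decode idea described above can also be sketched without simulating spikes.\n",
    "This minimal sketch assumes the standard steady-state LIF rate equation\n",
    "instead of the spike counting used in the script below,\n",
    "and it solves for decoders with `numpy.linalg.lstsq`\n",
    "rather than an explicit pseudo-inverse.\n",
    "The helpers `lif_rate` and `fit_decoders` are hypothetical names chosen for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "t_rc, t_ref = 0.02, 0.002  # same LIF constants as the script below\n",
    "\n",
    "\n",
    "def lif_rate(J):\n",
    "    # hypothetical helper: steady-state LIF rate (Hz), threshold at J = 1\n",
    "    J = np.asarray(J, dtype=float)\n",
    "    rate = np.zeros_like(J)\n",
    "    above = J > 1\n",
    "    rate[above] = 1.0 / (t_ref - t_rc * np.log1p(-1.0 / J[above]))\n",
    "    return rate\n",
    "\n",
    "\n",
    "def fit_decoders(x, encoders, gains, biases, target):\n",
    "    # encode: current = gain * (encoder * x) + bias, shape (samples, neurons)\n",
    "    J = gains * (x[:, None] * encoders) + biases\n",
    "    A = lif_rate(J)\n",
    "    # decode: least-squares weights mapping firing rates to target(x)\n",
    "    d, *_ = np.linalg.lstsq(A, target(x), rcond=None)\n",
    "    return A, d\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "n = 50\n",
    "encoders = rng.choice([-1.0, 1.0], size=n)\n",
    "intercepts = rng.uniform(-1, 1, size=n)\n",
    "max_rates = rng.uniform(25, 75, size=n)\n",
    "# same gain/bias formula as generate_gain_and_bias, vectorized\n",
    "z = 1.0 / (1 - np.exp((t_ref - 1.0 / max_rates) / t_rc))\n",
    "gains = (1 - z) / (intercepts - 1.0)\n",
    "biases = 1 - gains * intercepts\n",
    "\n",
    "x = np.linspace(-1, 1, 100)\n",
    "A, d = fit_decoders(x, encoders, gains, biases, lambda x: x ** 2)\n",
    "x_hat = A @ d  # decoded estimate of x ** 2\n",
    "print('max decoding error:', np.max(np.abs(x_hat - x ** 2)))\n",
    "```\n",
    "\n",
    "The decoders here are fitted to `x ** 2`, the same function the script computes.\n",
    "In the script below, the decoders of one population are combined with\n",
    "the encoders of the next to give the connection weights."
   ]
  },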
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
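    "To change the input to the system, edit the `input` function.\n",
    "To change the function computed by the connection weights, edit `function`.\n",
    "Because the neurons' intercepts are drawn from [-1, 1]\n",
    "and the decoders are only optimized over that range,\n",
    "both the input and the function's output should stay within [-1, 1].\n",
    "\n",
    "The size of the populations and their neural properties\n",
    "can also be adjusted by changing the parameters below."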
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import math\n",
    "import random\n",
    "\n",
    "import numpy\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "dt = 0.001  # simulation time step (s)\n",
    "t_rc = 0.02  # membrane RC time constant (s)\n",
    "t_ref = 0.002  # refractory period (s)\n",
    "t_pstc = 0.1  # post-synaptic time constant (s)\n",
    "N_A = 50  # number of neurons in first population\n",
    "N_B = 40  # number of neurons in second population\n",
    "N_samples = 100  # number of sample points to use when finding decoders\n",
    "rate_A = 25, 75  # range of maximum firing rates (Hz) for population A\n",
    "rate_B = 50, 100  # range of maximum firing rates (Hz) for population B\n",
    "\n",
    "\n",
    "def input(t):\n",
    "    \"\"\"The input to the system as a function of time t (s).\"\"\"\n",
    "    return math.sin(t)\n",
    "\n",
    "\n",
    "def function(x):\n",
    "    \"\"\"The function to compute between A and B.\"\"\"\n",
    "    return x * x"
   ]
  },
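  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a hedged example of the edits described above,\n",
    "the hypothetical definitions below replace the sine input with a square wave\n",
    "and the squaring function with a cube.\n",
    "Since the decoders are only fitted over [-1, 1],\n",
    "the small helper `in_range` (also hypothetical) checks that both the input\n",
    "and the function's output stay inside that interval.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def input(t):\n",
    "    # hypothetical alternative input: a square wave switching every 2 s\n",
    "    return 0.8 if math.floor(t / 2.0) % 2 == 0 else -0.8\n",
    "\n",
    "\n",
    "def function(x):\n",
    "    # hypothetical alternative function to compute between A and B\n",
    "    return x ** 3\n",
    "\n",
    "\n",
    "def in_range(f, samples, low=-1.0, high=1.0):\n",
    "    # True if f(s) lies in [low, high] for every sample point s\n",
    "    return all(low <= f(s) <= high for s in samples)\n",
    "\n",
    "\n",
    "sample_times = [i * 0.001 for i in range(10000)]  # 10 s at dt = 0.001\n",
    "grid = [i * 2.0 / 100 - 1.0 for i in range(101)]  # points in [-1, 1]\n",
    "print(in_range(input, sample_times), in_range(function, grid))\n",
    "```"
   ]
  },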
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# create random encoders for the two populations\n",
    "encoder_A = [random.choice([-1, 1]) for i in range(N_A)]\n",
    "encoder_B = [random.choice([-1, 1]) for i in range(N_B)]\n",
    "\n",
    "\n",
    "def generate_gain_and_bias(\n",
    "        count, intercept_low, intercept_high, rate_low, rate_high):\n",
    "    \"\"\"Return random gains and biases for `count` LIF neurons.\"\"\"\n",
    "    gain = []\n",
    "    bias = []\n",
    "    for i in range(count):\n",
    "        # desired intercept (value of x * encoder where firing starts)\n",
    "        intercept = random.uniform(intercept_low, intercept_high)\n",
    "        # desired maximum rate (firing rate when x * encoder = 1)\n",
    "        rate = random.uniform(rate_low, rate_high)\n",
    "\n",
    "        # this calculation is specific to LIF neurons: z is the input\n",
    "        # current at which an LIF neuron fires at `rate`, and the gain and\n",
    "        # bias are chosen so the current is 1 (the firing threshold) when\n",
    "        # x * encoder = intercept and z when x * encoder = 1\n",
    "        z = 1.0 / (1 - math.exp((t_ref - (1.0 / rate)) / t_rc))\n",
    "        g = (1 - z) / (intercept - 1.0)\n",
    "        b = 1 - g * intercept\n",
    "        gain.append(g)\n",
    "        bias.append(b)\n",
    "    return gain, bias\n",
    "\n",
    "\n",
    "# random gain and bias for the two populations\n",
    "gain_A, bias_A = generate_gain_and_bias(N_A, -1, 1, rate_A[0], rate_A[1])\n",
    "gain_B, bias_B = generate_gain_and_bias(N_B, -1, 1, rate_B[0], rate_B[1])\n",
    "\n",
    "\n",
    "def run_neurons(input, v, ref):\n",
    "    \"\"\"Advance the neuron model by one time step and return the spikes.\n",
    "\n",
    "    A simple leaky integrate-and-fire model, scaled so that v=0 is resting\n",
    "    voltage and v=1 is the firing threshold.\n",
    "    \"\"\"\n",
    "    spikes = []\n",
    "    for i, _ in enumerate(v):\n",
    "        dV = dt * (input[i] - v[i]) / t_rc  # the LIF voltage change equation\n",
    "        v[i] += dV\n",
    "        if v[i] < 0:\n",
    "            v[i] = 0  # don't allow voltage to go below 0\n",
    "\n",
    "        if ref[i] > 0:  # if we are in our refractory period\n",
    "            v[i] = 0   # keep voltage at zero and\n",
    "            ref[i] -= dt  # decrease the refractory period\n",
    "\n",
    "        if v[i] > 1:  # if we have hit threshold\n",
    "            spikes.append(True)  # spike\n",
    "            v[i] = 0  # reset the voltage\n",
    "            ref[i] = t_ref  # and set the refractory period\n",
    "        else:\n",
    "            spikes.append(False)\n",
    "    return spikes\n",
    "\n",
    "\n",
    "def compute_response(x, encoder, gain, bias, time_limit=0.5):\n",
    "    \"\"\"Measure the spike rate of a population for a given value x.\"\"\"\n",
    "    N = len(encoder)  # number of neurons\n",
    "    v = [0] * N  # voltage\n",
    "    ref = [0] * N  # refractory period\n",
    "\n",
    "    # compute input corresponding to x\n",
    "    input = []\n",
    "    for i in range(N):\n",
    "        input.append(x * encoder[i] * gain[i] + bias[i])\n",
    "        v[i] = random.uniform(0, 1)  # randomize the initial voltage level\n",
    "\n",
    "    count = [0] * N  # spike count for each neuron\n",
    "\n",
    "    # feed the input into the population for a given amount of time\n",
    "    t = 0\n",
    "    while t < time_limit:\n",
    "        spikes = run_neurons(input, v, ref)\n",
    "        for i, s in enumerate(spikes):\n",
    "            if s:\n",
    "                count[i] += 1\n",
    "        t += dt\n",
    "    return [c / time_limit for c in count]  # return the spike rate (in Hz)\n",
    "\n",
    "\n",
    "def compute_tuning_curves(encoder, gain, bias):\n",
    "    \"\"\"Compute the tuning curves for a population\"\"\"\n",
    "    # generate a set of x values to sample at\n",
    "    x_values = [i * 2.0 / N_samples - 1.0 for i in range(N_samples)]\n",
    "\n",
    "    # build up a matrix of neural responses to each input (i.e. tuning curves)\n",
    "    A = []\n",
    "    for x in x_values:\n",
    "        response = compute_response(x, encoder, gain, bias)\n",
    "        A.append(response)\n",
    "    return x_values, A\n",
    "\n",
    "\n",
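    "def compute_decoder(encoder, gain, bias, function=lambda x: x):\n",
    "    \"\"\"Find linear decoders that approximate `function` from the rates.\"\"\"\n",
    "    # get the tuning curves\n",
    "    x_values, A = compute_tuning_curves(encoder, gain, bias)\n",
    "\n",
    "    # get the desired decoded value for each sample point\n",
    "    value = numpy.array([[function(x)] for x in x_values])\n",
    "\n",
    "    # find the optimal linear decoder (least squares via the pseudo-inverse)\n",
    "    A = numpy.array(A).T\n",
    "    Gamma = numpy.dot(A, A.T)\n",
    "    Upsilon = numpy.dot(A, value)\n",
    "    Ginv = numpy.linalg.pinv(Gamma)\n",
    "    # the rates in A are spikes per second, but the simulation counts spikes\n",
    "    # per time step (rate * dt on average), so the decoders are divided by dt\n",
    "    decoder = numpy.dot(Ginv, Upsilon) / dt\n",
    "    return decoder\n",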
    "\n",
    "\n",
    "# find the decoders for A and B\n",
    "decoder_A = compute_decoder(encoder_A, gain_A, bias_A, function=function)\n",
    "decoder_B = compute_decoder(encoder_B, gain_B, bias_B)\n",
    "\n",
    "# compute the N_A x N_B weight matrix (A's decoders times B's encoders)\n",
    "weights = numpy.dot(decoder_A, [encoder_B])"
   ]
  },
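  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gain and bias formula in `generate_gain_and_bias` can be checked against\n",
    "the steady-state LIF rate equation, r(J) = 1 / (t_ref - t_rc ln(1 - 1/J)) for J > 1\n",
    "and r(J) = 0 otherwise, where the current J is in the same threshold-normalized units\n",
    "as `run_neurons`. This standalone sketch (the helper `lif_rate` is hypothetical)\n",
    "repeats the formula for a single neuron and confirms that the current equals\n",
    "the threshold 1 at the intercept and that the rate at x = 1 equals the requested maximum rate.\n",
    "\n",
    "```python\n",
    "import math\n",
    "import random\n",
    "\n",
    "t_rc, t_ref = 0.02, 0.002  # same constants as the script above\n",
    "\n",
    "\n",
    "def lif_rate(J):\n",
    "    # hypothetical helper: steady-state LIF rate (Hz), zero at or below J = 1\n",
    "    if J <= 1:\n",
    "        return 0.0\n",
    "    return 1.0 / (t_ref - t_rc * math.log(1 - 1.0 / J))\n",
    "\n",
    "\n",
    "# the same gain/bias formula as generate_gain_and_bias, for one neuron\n",
    "random.seed(1)\n",
    "intercept = random.uniform(-1, 1)\n",
    "rate = random.uniform(25, 75)\n",
    "z = 1.0 / (1 - math.exp((t_ref - (1.0 / rate)) / t_rc))\n",
    "g = (1 - z) / (intercept - 1.0)\n",
    "b = 1 - g * intercept\n",
    "\n",
    "print(g * intercept + b)  # current at the intercept: approximately 1\n",
    "print(lif_rate(g * 1.0 + b), rate)  # rate at x = 1 matches the target\n",
    "```\n",
    "\n",
    "The spiking simulation with `dt = 0.001` only approximates this analytic rate,\n",
    "since spikes and refractory periods are quantized to whole time steps."
   ]
  },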
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Running the simulation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "v_A = [0.0] * N_A  # voltage for population A\n",
    "ref_A = [0.0] * N_A  # refractory period for population A\n",
    "input_A = [0.0] * N_A  # input for population A\n",
    "\n",
    "v_B = [0.0] * N_B  # voltage for population B\n",
    "ref_B = [0.0] * N_B  # refractory period for population B\n",
    "input_B = [0.0] * N_B  # input for population B\n",
    "\n",
    "# scaling factor for the post-synaptic filter\n",
    "pstc_scale = 1.0 - math.exp(-dt / t_pstc)\n",
    "\n",
    "# for storing simulation data to plot afterward\n",
    "inputs = []\n",
    "times = []\n",
    "outputs = []\n",
    "ideal = []\n",
    "\n",
    "output = 0.0  # the decoded output value from population B\n",
    "t = 0\n",
    "while t < 10.0:  # noqa: C901 (tell static checker to ignore complexity)\n",
    "    # call the input function to determine the input value\n",
    "    x = input(t)\n",
    "\n",
    "    # convert the input value into an input for each neuron\n",
    "    for i in range(N_A):\n",
    "        input_A[i] = x * encoder_A[i] * gain_A[i] + bias_A[i]\n",
    "\n",
    "    # run population A and determine which neurons spike\n",
    "    spikes_A = run_neurons(input_A, v_A, ref_A)\n",
    "\n",
    "    # decay all of the inputs (implementing the post-synaptic filter)\n",
    "    for j in range(N_B):\n",
    "        input_B[j] *= (1.0 - pstc_scale)\n",
    "    # for each neuron that spikes, increase the input current\n",
    "    # of all the neurons it is connected to by the synaptic\n",
    "    # connection weight\n",
    "    for i, s in enumerate(spikes_A):\n",
    "        if s:\n",
    "            for j in range(N_B):\n",
    "                input_B[j] += weights[i][j] * pstc_scale\n",
    "\n",
    "    # compute the total input into each neuron in population B\n",
    "    # (taking into account gain and bias)\n",
    "    total_B = [0] * N_B\n",
    "    for j in range(N_B):\n",
    "        total_B[j] = gain_B[j] * input_B[j] + bias_B[j]\n",
    "\n",
    "    # run population B and determine which neurons spike\n",
    "    spikes_B = run_neurons(total_B, v_B, ref_B)\n",
    "\n",
    "    # for each neuron in B that spikes, update our decoded value\n",
    "    # (also applying the same post-synaptic filter)\n",
    "    output *= (1.0 - pstc_scale)\n",
    "    for j, s in enumerate(spikes_B):\n",
    "        if s:\n",
    "            output += decoder_B[j][0] * pstc_scale\n",
    "\n",
    "    if t % 0.5 <= dt:  # print progress roughly every 0.5 s\n",
    "        print(t, output)\n",
    "    times.append(t)\n",
    "    inputs.append(x)\n",
    "    outputs.append(output)\n",
    "    ideal.append(function(x))\n",
    "    t += dt"
   ]
  },
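  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `pstc_scale` updates above and the division by `dt` in `compute_decoder` fit together.\n",
    "With the update y becomes (1 - p) * y + p * w * s, where s is 1 on a step with a spike and 0 otherwise,\n",
    "the long-run average of y is w times the fraction of steps that contain a spike,\n",
    "which is r * dt for a neuron firing at r Hz.\n",
    "Choosing w = d / dt therefore recovers d * r, the rate-based decoded value.\n",
    "The standalone sketch below checks this for a hypothetical regular 50 Hz spike train\n",
    "and a unit decoder (d = 1).\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "dt = 0.001  # simulation time step (s)\n",
    "t_pstc = 0.1  # post-synaptic time constant (s)\n",
    "pstc_scale = 1.0 - math.exp(-dt / t_pstc)\n",
    "\n",
    "rate = 50.0  # regular spike train at 50 Hz: one spike every 20 steps\n",
    "period = int(round(1.0 / (rate * dt)))\n",
    "weight = 1.0 / dt  # a unit decoder scaled by 1 / dt, as in compute_decoder\n",
    "\n",
    "y = 0.0\n",
    "trace = []\n",
    "for step in range(10000):  # 10 s of simulated time\n",
    "    y *= 1.0 - pstc_scale  # exponential decay of the filtered input\n",
    "    if step % period == 0:  # a spike arrives on this step\n",
    "        y += weight * pstc_scale\n",
    "    trace.append(y)\n",
    "\n",
    "# average over the last 5 s, long after the initial transient has decayed\n",
    "steady = trace[5000:]\n",
    "print(sum(steady) / len(steady))  # close to the firing rate, 50\n",
    "```\n",
    "\n",
    "The filtered value still ripples around this average between spikes,\n",
    "which is one source of the noise visible in the decoded output."
   ]
  },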
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Plotting the results"
   ]
  },
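  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Besides the plots, a single number can summarize how closely the decoded output tracks the ideal value.\n",
    "The hypothetical helper `rmse` below computes the root-mean-square error between two sequences.\n",
    "Its optional `lag_steps` argument is an addition of this sketch, not part of the original script;\n",
    "it shifts the output back to allow for the delay that the post-synaptic filter introduces.\n",
    "After Step 2 it could be called as `rmse(outputs, ideal)`,\n",
    "or for example as `rmse(outputs, ideal, lag_steps=100)` for a 0.1 s shift.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def rmse(output, target, lag_steps=0):\n",
    "    # root-mean-square error comparing output[k + lag_steps] with target[k]\n",
    "    if lag_steps:\n",
    "        output = output[lag_steps:]\n",
    "        target = target[:len(target) - lag_steps]\n",
    "    n = min(len(output), len(target))\n",
    "    if n == 0:\n",
    "        raise ValueError('no overlapping samples to compare')\n",
    "    total = sum((o - t) ** 2 for o, t in zip(output, target))\n",
    "    return math.sqrt(total / n)\n",
    "\n",
    "\n",
    "# toy check: an output that lags the target by one step\n",
    "target = [0.0, 0.1, 0.2, 0.3, 0.4]\n",
    "output = [0.0, 0.0, 0.1, 0.2, 0.3]\n",
    "print(rmse(output, target), rmse(output, target, lag_steps=1))\n",
    "```"
   ]
  },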
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x, A = compute_tuning_curves(encoder_A, gain_A, bias_A)\n",
    "x, B = compute_tuning_curves(encoder_B, gain_B, bias_B)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(x, A)\n",
    "plt.title('Tuning curves for population A')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('firing rate (Hz)')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(x, B)\n",
    "plt.title('Tuning curves for population B')\n",
    "plt.xlabel('x')\n",
    "plt.ylabel('firing rate (Hz)')\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(times, inputs, label='input')\n",
    "plt.plot(times, ideal, label='ideal')\n",
    "plt.plot(times, outputs, label='output')\n",
    "plt.title('Simulation results')\n",
    "plt.xlabel('time (s)')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
